import itertools
import os
from collections import Counter

import numpy as np

from src.clustering.base_distances_bis import (
    norm_hamming,
    jaccard,
    rms_jaccard,
    geom_jaccard,
    rms_hamming,
    geom_hamming,
)
from src.converters.convert_votes_to_vecs import (
    convert_profiles_to_vec_profiles
)
from src.imports.import_pabulib import import_pabulib_files_from_folder

# Registry of the base distance functions to evaluate, keyed by the name that
# is embedded in each output CSV file name.
base_distances = dict(
    jaccard=jaccard,
    geom_jaccard=geom_jaccard,
    norm_hamming=norm_hamming,
    rms_jaccard=rms_jaccard,
    geom_hamming=geom_hamming,
    rms_hamming=rms_hamming,
)

def _mean_or_nan(values):
    """Return ``np.mean(values)``, or NaN for an empty list.

    Avoids numpy's "Mean of empty slice" RuntimeWarning.
    """
    return np.mean(values) if values else np.nan


def category_distance_ratios(P, project_categories, distance):
    """Compare distances among same-category projects with the rest.

    For every category that occurs at least twice, compute two averages and
    divide each by the mean distance over all project pairs:

    - intra: mean distance between two projects that both carry the category;
    - inter: mean distance between a project that carries it and one that
      does not.

    Args:
        P: 2-D array whose column ``P[:, i]`` is the vector of project ``i``.
        project_categories: sequence whose ``i``-th item is the collection of
            categories of project ``i``. Its length is the number of projects.
        distance: callable taking two project vectors and returning a number.

    Returns:
        dict mapping each category that occurs at least twice to an
        ``(intra, inter)`` tuple of normalized mean distances. Categories are
        ordered by first appearance. A ratio is NaN when it has no pair to
        average.
    """
    num_projects = len(project_categories)

    # Compute every pair's distance once and reuse it for all categories.
    pair_distances = {
        (i, j): distance(P[:, i], P[:, j])
        for i, j in itertools.combinations(range(num_projects), 2)
    }
    avg_d = _mean_or_nan(list(pair_distances.values()))

    # Counts category occurrences across all projects' category lists.
    category_counts = Counter(
        category for categories in project_categories for category in categories
    )

    ratios = {}
    for category, count in category_counts.items():
        if count == 1:
            continue  # a single occurrence cannot form an intra-category pair

        # Fresh lists per category, so each average covers only this category.
        intra, inter = [], []
        for (i, j), d in pair_distances.items():
            in_i = category in project_categories[i]
            in_j = category in project_categories[j]
            if in_i and in_j:
                intra.append(d)
            elif in_i != in_j:
                inter.append(d)

        ratios[category] = (
            _mean_or_nan(intra) / avg_d,
            _mean_or_nan(inter) / avg_d,
        )
    return ratios


if __name__ == "__main__":

    targets = ['warszawa_2023_districts', 'warszawa_2024_districts',
               'krakow_2023_districts', 'krakow_2024_districts',
               'lodz_2023_districts', 'lodz_2024_districts']

    for target in targets:

        instances, profiles = import_pabulib_files_from_folder(f'data/pabulib/{target}')
        vec_profiles, names = convert_profiles_to_vec_profiles(profiles)

        output_dir = f"output/pabulib/cat/{target}"
        os.makedirs(output_dir, exist_ok=True)

        for instance_id in instances:

            # The first axis of vec_profiles[instance_id] is indexed by
            # project; after transposing, column i of P is project i.
            P = np.transpose(np.array(vec_profiles[instance_id]))
            num_projects = len(vec_profiles[instance_id])

            project_meta = instances[instance_id].project_meta
            project_ids = [names[instance_id][i] for i in range(num_projects)]

            # Skip instances whose projects lack category metadata.
            if not all('categories' in project_meta[pid] for pid in project_ids):
                continue
            project_categories = [project_meta[pid]['categories'] for pid in project_ids]

            for base_distance, distance in base_distances.items():
                ratios = category_distance_ratios(P, project_categories, distance)

                # Export the intra and inter ratios to CSV. The "closest"
                # column is not computed and is always written as 0.
                csv_path = f"{output_dir}/{base_distance}_distances_{instance_id}.csv"
                with open(csv_path, "w", encoding="utf-8") as csv_file:
                    csv_file.write("category;intra;inter;closest\n")
                    for cat, (intra, inter) in ratios.items():
                        csv_file.write(f"{cat};{intra};{inter};{0}\n")





